Skip to content

[CK] Fix MoE 2-stage dispatch for non-128-divisible inter_dim#3973

Open
jonahbernard wants to merge 1 commit into
ROCm:mainfrom
jonahbernard:gelu-tanh-moe-dim-fix
Open

[CK] Fix MoE 2-stage dispatch for non-128-divisible inter_dim#3973
jonahbernard wants to merge 1 commit into
ROCm:mainfrom
jonahbernard:gelu-tanh-moe-dim-fix

Conversation

@jonahbernard

Copy link
Copy Markdown

Motivation

The gfx950 heuristic dispatch sent all inter_dim > 192 shapes to the NPerBlock/KPerBlock=128 fast path, which fails CK's N%NPerBlock (stage1) and K%KPerBlock (stage2) divisibility checks when inter_dim is not a multiple of 128 (e.g. DiffusionGemma moe_inter=704=64*11). Route those shapes to the PerBlock=64 instances, which divide any multiple of 64.

Technical Details

Widen pre-existing dispatch heuristic if (inter_dim <= 192) to if (inter_dim <= 192 || inter_dim % 128 != 0).

Test Plan

Verified on gfx950 with the no-quant (a16w16) legacy CK 2-stage path at inter_dim=704, sweeping token counts from 1 to 163840 (covering all block_m selections: 32 / 64 / 128):

python3 op_tests/test_moe_2stage.py --no-flydsl-csv -q 0 -dim 7168,704

Test Result

  ┌────────┬─────────┬─────────────┐
  │ token  │ block_m │ logits_diff │
  ├────────┼─────────┼─────────────┤
  │      1 │      32 │    9.53e-06 │
  ├────────┼─────────┼─────────────┤
  │      3 │      32 │    9.13e-06 │
  ├────────┼─────────┼─────────────┤
  │      5 │      32 │    1.02e-05 │
  ├────────┼─────────┼─────────────┤
  │     16 │      32 │    1.04e-05 │
  ├────────┼─────────┼─────────────┤
  │     32 │      32 │    1.03e-05 │
  ├────────┼─────────┼─────────────┤
  │     64 │      32 │    1.03e-05 │
  ├────────┼─────────┼─────────────┤
  │    128 │      32 │    1.01e-05 │
  ├────────┼─────────┼─────────────┤
  │    256 │      64 │    1.01e-05 │
  ├────────┼─────────┼─────────────┤
  │   1024 │     128 │    1.01e-05 │
  ├────────┼─────────┼─────────────┤
  │   4096 │     128 │    1.01e-05 │
  ├────────┼─────────┼─────────────┤
  │   8192 │     128 │    1.01e-05 │
  ├────────┼─────────┼─────────────┤
  │ 163840 │     128 │    1.01e-05 │
  └────────┴─────────┴─────────────┘

Submission Checklist

…ances

The gfx950 heuristic dispatch sent all inter_dim > 192 shapes to the
NPerBlock/KPerBlock=128 fast path, which fails CK's N%NPerBlock (stage1)
and K%KPerBlock (stage2) divisibility checks when inter_dim is not a
multiple of 128 (e.g. DiffusionGemma moe_inter=704=64*11). Route those
shapes to the PerBlock=64 instances, which divide any multiple of 64.
@jonahbernard jonahbernard requested a review from a team June 27, 2026 22:40
@github-actions

Copy link
Copy Markdown
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3973 --add-label <label>

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant